from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import torch
from torch import nn


class Task(ABC):
    """Abstract interface bundling the task-specific computations for a model.

    Subclasses must implement `compute_train_loss` and `compute_measurement`.
    The remaining methods are optional hooks whose defaults are no-ops
    (returning `None` or the input unchanged).

    Attributes:
        enable_post_process_per_sample_gradient (bool):
            Class-level flag, `False` by default.
            NOTE(review): presumably consumers only invoke
            `post_process_per_sample_gradient` when this is `True` -- confirm
            against the call sites.
    """

    enable_post_process_per_sample_gradient: bool = False

    @abstractmethod
    def compute_train_loss(
        self,
        batch: Any,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        """Compute the training loss of `model` on `batch`.

        Args:
            batch (Any):
                A batch of data; its structure is defined by the subclass.
            model (nn.Module):
                The model for which the loss is computed.
            sample (bool):
                Defaults to `False`.
                NOTE(review): presumably selects between using the targets in
                `batch` and sampling targets from the model's outputs --
                confirm against the callers.

        Returns:
            torch.Tensor:
                The computed loss.

        Raises:
            NotImplementedError:
                If a subclass delegates to this base implementation.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_train_loss` method.")

    @abstractmethod
    def compute_measurement(
        self,
        batch: Any,
        model: nn.Module,
    ) -> torch.Tensor:
        """Compute the measurement of `model` on `batch`.

        Args:
            batch (Any):
                A batch of data; its structure is defined by the subclass.
            model (nn.Module):
                The model for which the measurement is computed.

        Returns:
            torch.Tensor:
                The computed measurement.

        Raises:
            NotImplementedError:
                If a subclass delegates to this base implementation.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_measurement` method.")

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        """Return the names of the modules to track, or `None`.

        The default implementation returns `None`. Override it to return a
        list of module names.
        NOTE(review): `None` presumably means "no restriction, track every
        supported module" -- confirm against the consumer.

        Returns:
            Optional[List[str]]:
                A list of module names, or `None` (the default).
        """
        return None

    def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
        """Return the attention mask associated with `batch`, or `None`.

        The default implementation ignores `batch` and returns `None`.
        Override it to return either a single tensor or a dictionary of
        tensors keyed by string.
        NOTE(review): the dictionary keys are presumably module names --
        confirm against the consumer.

        Args:
            batch (Any):
                A batch of data; its structure is defined by the subclass.

        Returns:
            Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
                The attention mask(s), or `None` (the default).
        """
        # The base hook has no use for the batch; dropping the name marks it
        # as intentionally unused, mirroring `post_process_per_sample_gradient`.
        del batch
        return None

    def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor:
        """Post-process the per-sample gradient of a module.

        The default implementation is the identity: it returns `gradient`
        unchanged (the very same object).

        Args:
            module_name (str):
                Name of the module the gradient belongs to. Unused by the
                default implementation.
            gradient (torch.Tensor):
                The per-sample gradient tensor.

        Returns:
            torch.Tensor:
                The post-processed gradient; by default, `gradient` itself.
        """
        # Explicitly discard the unused argument.
        del module_name
        return gradient
